/*
 * ReferenceElement.scala
 * Elements representing references and aggregates over references.
 *
 * Created By:      Avi Pfeffer (apfeffer@cra.com)
 * Creation Date:   Jan 1, 2009
 *
 * Copyright 2017 Avrom J. Pfeffer and Charles River Analytics, Inc.
 * See http://www.cra.com or email figaro@cra.com for information.
 *
 * See http://www.github.com/p2t2/figaro for a copy of the software license.
 */

package com.cra.figaro.language

import com.cra.figaro.algorithm.ValuesMaker
import com.cra.figaro.algorithm.lazyfactored._
import ValueSet._
import com.cra.figaro.algorithm.factored.factors._
import com.cra.figaro.algorithm.factored.factors.factory.Factory
import com.cra.figaro.util._
import scala.collection.mutable.Map
import scala.language.existentials

/**
 * Base class for elements whose value is obtained by resolving a reference within an element collection.
 *
 * @param coll The collection to use to resolve the reference.
 * @param reference The reference whose value is represented by this element.
 * @tparam T The value type of the elements the reference points to.
 * @tparam U The value type of this element.
 */
abstract class ReferenceElement[T, U](coll: ElementCollection, val reference: Reference[T])
  extends Deterministic[U]("", coll) with IfArgsCacheable[U] {
  // The arguments of this element are whatever the collection reports via makeArgs for the reference,
  // converted to a list. Declared lazy, so it is only computed on first access.
  lazy val args = {
    val argumentsOfReference = collection.makeArgs(reference)
    argumentsOfReference.toList
  }
}

/**
 * Element representing a single-valued reference. Its value in a state is generated by following the reference
 * in the state and taking the value of the resulting element.
 *
 * @param collection The collection to use to resolve the reference.
 * @param reference The reference whose value is represented by this element.
 */
class SingleValuedReferenceElement[T](collection: ElementCollection, reference: Reference[T])
  extends ReferenceElement[T, T](collection, reference) with ValuesMaker[T] {

  /**
   * Resolves the reference in the current state, invokes generateValue on the resolved element with that
   * element's own randomness, and returns the resolved element's value.
   */
  def generateValue(): T = {
    val target = collection.getElementByReference(reference)
    target.generateValue(target.randomness)
    target.value
  }

  /*
   * For an indirect reference, every nested reference element created while computing values is remembered here,
   * keyed by the element collection it resolves against, so that the very same embedded elements can be reused
   * later (e.g. when building factors) instead of fresh copies.
   * Turning SingleValuedReferenceElement into a case class would give a similar sharing effect, but that is wrong:
   * under reference uncertainty one reference can occur twice in a collection with different embedded references.
   */
  val embeddedElements: Map[ElementCollection, SingleValuedReferenceElement[T]] = Map()

  /**
   * Returns all possible values of this element's reference.
   * A direct reference yields the values of the referred-to element. An indirect reference yields the union of
   * the values of the remaining reference resolved in each possible collection of the first name, expanded at
   * depth - 1; a star is included whenever the first name's value set has one.
   */
  def makeValues(depth: Int): ValueSet[T] =
    collection.getFirst(reference) match {
      case (head, None) =>
        LazyValues(universe)(head.asInstanceOf[Element[T]], depth)
      case (head, Some(tail)) =>
        val headValues = LazyValues(universe)(head.asInstanceOf[Element[ElementCollection]], depth)
        val tailValueSets = headValues.regularValues.map { headColl =>
          val embedded = new SingleValuedReferenceElement(headColl, tail)
          embeddedElements += headColl -> embedded
          LazyValues(universe)(embedded, depth - 1)
        }
        val initial: ValueSet[T] =
          if (headValues.hasStar) ValueSet.withStar(Set()) else ValueSet.withoutStar(Set())
        tailValueSets.foldLeft(initial)((acc, values) => acc ++ values)
    }

  //  def makeFactors: List[Factor[Double]] = {
  //    val (first, rest) = collection.getFirst(reference)
  //    rest match {
  //      case None =>
  //        val thisVar = Variable(this)
  //        val refVar = Variable(first)
  //        val factor = Factory.make[Double](List(thisVar, refVar))
  //        for {
  //          i <- 0 until refVar.range.size
  //          j <- 0 until refVar.range.size
  //        } {
  //          factor.set(List(i, j), (if (i == j) 1.0; else 0.0))
  //        }
  //        List(factor)
  //      case Some(restRef) =>
  //        val firstVar = Variable(first)
  //        val selectedFactors =
  //          for {
  //            (firstXvalue, firstIndex) <- firstVar.range.zipWithIndex
  //            firstCollection = firstXvalue.value.asInstanceOf[ElementCollection]
  //            restElement = embeddedElements(firstCollection)
  //          } yield {
  //            Factory.makeConditionalSelector(this, firstVar, firstIndex, Variable(restElement)) :: restElement.makeFactors
  //          }
  //        selectedFactors.flatten
  //    }
  //  }
}

/**
 * Aggregate elements based on multi-valued references. Note that the values aggregated over are all the values of
 * all elements that are referred to by the reference. If the same element is reachable by more than one path in the
 * reference, its value is only included once. However, if two different referred to elements have the same value,
 * the value is included multiple times. Since the order of these values is immaterial, we use a multiset to represent them.
 *
 * @param collection The collection to use to resolve the reference.
 * @param reference The multi-valued reference whose referred-to values are aggregated.
 * @param aggregate A function to aggregate elements referred to by the reference into a value of this aggregate element.
 */
class Aggregate[T, U](collection: ElementCollection, reference: Reference[T], val aggregate: MultiSet[T] => U)
  extends ReferenceElement[T, U](collection, reference) with ValuesMaker[U] {

  // Multi-valued reference element over the same reference, created at construction time; the possible
  // values of this aggregate are computed from its possible multisets.
  val mvre = new MultiValuedReferenceElement(collection, reference)

  /** Applies the aggregate function to the multiset of current values of all referred-to elements. */
  def generateValue(): U = {
    val referredValues = collection.getManyElementsByReference(reference).toList.map(_.value)
    aggregate(HashMultiSet(referredValues: _*))
  }

  /**
   * Maps the aggregate function over every regular multiset the embedded MVRE can take at the given depth.
   * The result carries a star exactly when the MVRE's value set does.
   */
  def makeValues(depth: Int): ValueSet[U] = {
    val inputMultisets = LazyValues(universe)(mvre, depth)
    val outputs = inputMultisets.regularValues.map(aggregate)
    if (!inputMultisets.hasStar) withoutStar(outputs) else withStar(outputs)
  }

//  def makeFactors = {
//    val thisVar = Variable(this)
//    val mvreVar = Variable(mvre)
//    val factor = Factory.make[Double](List(thisVar, mvreVar))
//    for {
//      (thisXvalue, thisIndex) <- thisVar.range.zipWithIndex
//      (mvreXvalue, mvreIndex) <- mvreVar.range.zipWithIndex
//    } {
//      if (thisXvalue.isRegular && mvreXvalue.isRegular) factor.set(List(thisIndex, mvreIndex), if (aggregate(mvreXvalue.value) == thisXvalue.value) 1.0; else 0.0)
//    }
//    // The MultiValuedReferenceElement for this aggregate is generated when values is called.
//    // Therefore, it will be included in the expansion and have factors made for it automatically, so we do not create factors for it here.
//    List(factor)
//  }
}

/**
 * Element representing the values of a reference that can have multiple values.
 * Its value is the multiset of the values of all elements the reference resolves to.
 *
 * @param coll The collection to use to resolve the reference.
 * @param ref The reference whose value is represented by this element.
 */
class MultiValuedReferenceElement[T](coll: ElementCollection, ref: Reference[T]) extends ReferenceElement[T, MultiSet[T]](coll, ref)
  with ValuesMaker[MultiSet[T]] {

  /*
   * We need to make sure that if the reference is indirect, each of the reference elements embedded in this element have their values
   * determined, and then the same embedded elements are used when making factors. This is achieved by storing the embedded elements
   * in a map.
   * Note that we could have achieved a similar effect by making MultiValuedReferenceElement a case class, but that would be incorrect
   * since with reference uncertainty, the same reference may exist twice in a collection with two different embedded references.
   *
   * Since the MVRE factor maker uses embedded Inject and Apply elements, we need to make sure they are expanded and have their
   * values computed when makeValues is called. Therefore, we store them in embeddedInject and embeddedApply. The key to these maps
   * is the list of element collections that represents a possible value of the head of the reference.
   */
  val embeddedElements: Map[ElementCollection, MultiValuedReferenceElement[T]] = Map()
  val embeddedInject: Map[List[ElementCollection], Element[List[MultiSet[T]]]] = Map()
  val embeddedApply: Map[List[ElementCollection], Element[MultiSet[T]]] = Map()

  // collection.getElements is a set of elements, because if the same element is reachable by more than one path, it is only counted once.
  // We convert it to a list so we get all the values of these elements, even if some of the elements have the same values.
  def generateValue(): MultiSet[T] = {
    val referredToElements = collection.getManyElementsByReference(reference).toList
    referredToElements.foreach(elem => elem.generateValue(elem.randomness))
    HashMultiSet(referredToElements: _*) map ((e: Element[T]) => e.value)
  }

  /**
   * Returns the pairwise multiset unions of the regular multisets of the two value sets.
   * The result has a star if either input has one.
   */
  private def allUnions(setset1: ValueSet[MultiSet[T]], setset2: ValueSet[MultiSet[T]]): ValueSet[MultiSet[T]] = {
    val multiSets = for { set1 <- setset1.regularValues; set2 <- setset2.regularValues } yield set1 union set2
    if (setset1.hasStar || setset2.hasStar) withStar(multiSets); else withoutStar(multiSets)
  }

  /*
   * The value of a MultiValuedReferenceElement is a multiset.
   * To find all possible values, we proceed as follows:
   * If the reference is simple (i.e., just a name), the element referred to by that name has some set of values.
   * Each of those values becomes a singleton multiset in the values of the MVRE.
   *
   * If the reference is compound, each of the first name's possible values may either be a single EC or a traversable of ECs.
   * If it's just an EC, we get the values of the rest of the reference (which are multisets)
   * and those become the possible values of the MVRE associated with this first value.
   * If it's a traversable of ECs, we get the possible values of the rest of the reference for each of the ECs in the traversable.
   * Each such value is a multiset.
   * So, at this point, we have a traversable T of value sets of multisets. We need to convert it into a single value set
   * of multisets such that each multiset in the resulting value set is the multiset union of one multiset chosen for
   * each of the value sets in T. This is accomplished by allUnions.
   * As a result, we get a value set of multisets for each possible value of the name.
   * The final step is to take the value set union of these value sets and return it as the value set for this MVRE.
   *
   * Any other kind of value for the first name of a compound reference is rejected with an IllegalArgumentException.
   */
  def makeValues(depth: Int): ValueSet[MultiSet[T]] = {
    val (first, rest) = collection.getFirst(reference)
    rest match {
      case None =>
        val firstValues = LazyValues(universe)(first.asInstanceOf[Element[T]], depth)
        firstValues.map((t: T) => HashMultiSet(List(t): _*))
      case Some(restRef) =>
        val results: Set[ValueSet[MultiSet[T]]] = {
          val firstValues: ValueSet[_] = LazyValues(universe)(first, depth)
          for { firstXvalue <- firstValues.xvalues } yield {
            if (firstXvalue.isRegular) {
              firstXvalue.value match {
                case firstColl: ElementCollection =>
                  val embedded = new MultiValuedReferenceElement(firstColl, restRef)
                  embeddedElements += firstColl -> embedded
                  LazyValues(universe)(embedded, depth - 1)
                case ecs: Traversable[_] =>
                  // Aggregates use set semantics for the elements they use. If the same element appears more than once, it is only counted once.
                  val collections: List[ElementCollection] = ecs.map((x: Any) => x.asInstanceOf[ElementCollection]).toList.distinct
                  val multis = {
                    for {
                      firstColl <- collections
                    } yield {
                      val restMulti = new MultiValuedReferenceElement(firstColl, restRef)
                      embeddedElements += firstColl -> restMulti
                      restMulti
                    }
                  }
                  val combination: Element[List[MultiSet[T]]] = Inject(multis: _*)
                  // Result discarded: called so that the embedded Inject is expanded and its values computed (see above).
                  LazyValues(universe)(combination, depth - 1)
                  embeddedInject += collections -> combination
                  val applyStarter: MultiSet[T] = HashMultiSet[T]()
                  val setMaker = Apply(combination, (sets: List[MultiSet[T]]) => (applyStarter /: sets)(_ union _))
                  // Result discarded: called so that the embedded Apply is expanded and its values computed (see above).
                  LazyValues(universe)(setMaker, depth - 1)
                  embeddedApply += collections -> setMaker
                  val resultSets: List[ValueSet[MultiSet[T]]] =
                    multis.map(multi => LazyValues(universe)(multi, depth - 1))
                  /*
                   * Here, we are creating the value set of multisets in which each multiset is the multiset union of
                   * multisets in the result sets. Since we are taking the multiset union, we start with the empty
                   * multiset, which is why the start is the value set containing the single empty multiset.
                   */
                  val starter: ValueSet[MultiSet[T]] = withoutStar(Set(HashMultiSet()))
                  (starter /: resultSets)(allUnions(_, _))
                case other =>
                  // Previously this fell through to an opaque MatchError; fail with an explanation instead.
                  throw new IllegalArgumentException(
                    s"Reference $reference: expected the value of the first name to be an ElementCollection " +
                      s"or a Traversable of ElementCollections, but got $other")
              }
            } else withStar[MultiSet[T]](Set())
          }
        }
        /*
         * Here we are taking the value set union of the value sets produced for each of the possible first values,
         * so that starter is the empty value set.
         */
        val starter: ValueSet[MultiSet[T]] = withoutStar(Set())
        (starter /: results)(_ ++ _)
    }
  }

  //    def makeFactors: List[Factor[Double]] = {
  //      val (first, rest) = collection.getFirst(reference)
  //      val selectionFactors: List[List[Factor[Double]]] = {
  //        rest match {
  //          case None =>
  //            val thisVar = Variable(this)
  //            val refVar = Variable(first)
  //            val factor = Factory.make[Double](List(thisVar, refVar))
  //            for {
  //              i <- 0 until refVar.range.size
  //              j <- 0 until refVar.range.size
  //            } {
  //              factor.set(List(i, j), (if (i == j) 1.0; else 0.0))
  //            }
  //            List(List(factor))
  //          case Some(restRef) =>
  //            val firstVar = Variable(first)
  //            for {
  //              (firstXvalue, firstIndex) <- firstVar.range.zipWithIndex
  //            } yield {
  //              if (firstXvalue.isRegular) {
  //                firstXvalue.value match {
  //                  case firstCollection: ElementCollection =>
  //                    val restElement = embeddedElements(firstCollection)
  //                    val result: List[Factor[Double]] =
  //                      Factory.makeConditionalSelector(this, firstVar, firstIndex, Variable(restElement)) :: Factory.make(restElement)
  //                    result
  //                  case cs: Traversable[_] =>
  //                    // Create a multi-valued reference element (MVRE) for each collection in the value of the first name.
  //                    // Since the first name is multi-valued, its value is the union of the values of all these MVREs.
  //                    val collections = cs.asInstanceOf[Traversable[ElementCollection]].toList.distinct // Set semantics
  //                    val multis: List[MultiValuedReferenceElement[T]] = collections.map(embeddedElements(_)).toList
  //                    // Create the element that takes the union of the values of the all the MVREs.
  //                    // The combination and setMaker elements are encapsulated within this object and are created now, so we need to create factors for them.
  //                    // Finally, we create a conditional selector (see ProbFactor) to select the appropriate result value when the first
  //                    // name's value is these MVREs.
  //                    val combination = embeddedInject(collections)
  //                    val setMaker = embeddedApply(collections)
  //                    val result: List[Factor[Double]] =
  //                      Factory.makeConditionalSelector(this, firstVar, firstIndex, Variable(setMaker)) :: Factory.make(combination) :::
  //                        Factory.make(setMaker)
  //                    result
  //                }
  //              } else Factory.makeStarFactor(this)
  //            }
  //        }
  //      }
  //      selectionFactors.flatten
  //    }
}
